Switch
根据条件选择输入张量。该算子根据布尔条件 condition 的值,选择 input_x 或 input_y 作为输出。该算子不区分数据类型,适用于所有数据类型。
\[\begin{split}\text{output} = \begin{cases}
\text{input\_x}, & \text{if } \text{condition} = \text{True} \\
\text{input\_y}, & \text{if } \text{condition} = \text{False}
\end{cases}\end{split}\]
该算子不复制数据,只是将输出指针指向选中的输入张量。因此,输出张量共享输入张量的数据指针和元数据。
- 输入:
input_x - 第一个输入张量(TensorC* 类型)。当 condition 为 True 时被选中。
input_y - 第二个输入张量(TensorC* 类型)。当 condition 为 False 时被选中。
condition - 条件值(bool 类型),决定选择哪个输入张量。
- 输出:
output - 输出张量指针的指针(TensorC** 类型),指向选中的输入张量。
- 支持平台:
FT78NEMT7004
备注
该算子不区分数据类型,适用于所有数据类型
算子不复制数据,输出张量共享输入张量的数据指针
输出张量的所有元数据(形状、数据类型、格式等)与选中的输入张量相同
共享存储版本:
-
void switch_s(TensorC *input_x, TensorC *input_y, TensorC **output, bool condition)
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <switch.h>
4
5int main(int argc, char* argv[]) {
6 TensorC input_x;
7 TensorC input_y;
8 TensorC* output;
9
10 // 初始化 input_x
11 int x_shape[3] = {2, 3, 4};
12 memcpy(input_x.shape_, x_shape, 3 * sizeof(int));
13 input_x.shape_size_ = 3;
14 input_x.data_type_ = kNumberTypeFloat32;
15 input_x.format_ = Format_NCHW;
16 input_x.data_ = (void *)0xA0000000;
17 input_x.category_ = 0; // 非常量
18 input_x.shape_changed_ = false;
19
20 // 初始化 input_y
21 int y_shape[3] = {2, 3, 4};
22 memcpy(input_y.shape_, y_shape, 3 * sizeof(int));
23 input_y.shape_size_ = 3;
24 input_y.data_type_ = kNumberTypeFloat32;
25 input_y.format_ = Format_NCHW;
26 input_y.data_ = (void *)0xB0000000;
27 input_y.category_ = 0;
28 input_y.shape_changed_ = false;
29
30 bool condition = true; // 选择 input_x
31
32 switch_s(&input_x, &input_y, &output, condition);
33
34 return 0;
35}
私有存储版本:
-
void switch_p(TensorC *input_x, TensorC *input_y, TensorC **output, bool condition)
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <switch.h>
4
5int main(int argc, char* argv[]) {
6 TensorC input_x;
7 TensorC input_y;
8 TensorC* output;
9
10 // 初始化 input_x
11 int x_shape[3] = {2, 3, 4};
12 memcpy(input_x.shape_, x_shape, 3 * sizeof(int));
13 input_x.shape_size_ = 3;
14 input_x.data_type_ = kNumberTypeFloat32;
15 input_x.format_ = Format_NCHW;
16 input_x.data_ = (void *)0x10000000;
17 input_x.category_ = 0; // 非常量
18 input_x.shape_changed_ = false;
19
20 // 初始化 input_y
21 int y_shape[3] = {2, 3, 4};
22 memcpy(input_y.shape_, y_shape, 3 * sizeof(int));
23 input_y.shape_size_ = 3;
24 input_y.data_type_ = kNumberTypeFloat32;
25 input_y.format_ = Format_NCHW;
26 input_y.data_ = (void *)0x10001000;
27 input_y.category_ = 0;
28 input_y.shape_changed_ = false;
29
30 bool condition = true; // 选择 input_x
31
32 switch_p(&input_x, &input_y, &output, condition);
33
34 return 0;
35}